{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Quantized MLP on UNSW-NB15 with Brevitas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font color=\"red\">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for \"latency hiding\".</font>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will show how to create, train and export a quantized Multi Layer Perceptron (MLP) with quantized weights and activations with [Brevitas](https://github.com/Xilinx/brevitas).\n",
    "Specifically, the task at hand will be to label network packets as normal or suspicious (e.g. originating from an attacker, virus, malware or otherwise) by training on a quantized variant of the UNSW-NB15 dataset. \n",
    "\n",
    "**You won't need a GPU to train the neural net.** This MLP will be small enough to train on a modern x86 CPU, so no GPU is required to follow this tutorial  Alternatively, we provide pre-trained parameters for the MLP if you want to skip the training entirely.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A quick introduction to the task and the dataset\n",
    "\n",
    "*The task:* The goal of [*network intrusion detection*](https://ieeexplore.ieee.org/abstract/document/283931) is to identify, preferably in real time, unauthorized use, misuse, and abuse of computer systems by both system insiders and external penetrators. This may be achieved by a mix of techniques, and machine-learning (ML) based techniques are increasing in popularity. \n",
    "\n",
    "*The dataset:* Several datasets are available for use in ML-based methods for intrusion detection.\n",
    "The **UNSW-NB15** is one such dataset created by the Australian Centre for Cyber Security (ACCS) to provide a comprehensive network based data set which can reflect modern network traffic scenarios. You can find more details about the dataset on [its homepage](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/).\n",
    "\n",
    "*Performance considerations:* FPGAs are commonly used for implementing high-performance packet processing systems that still provide a degree of programmability. To avoid introducing bottlenecks on the network, the DNN implementation must be capable of detecting malicious ones at line rate, which can be millions of packets per second, and is expected to increase further as next-generation networking solutions provide increased\n",
    "throughput. This is a good reason to consider FPGA acceleration for this particular use-case."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "-------------\n",
    "\n",
    "* [Load the UNSW_NB15 Dataset](#load_dataset) \n",
    "* [Define a PyTorch Device](#define_pytorch_device)\n",
    "* [Define the Quantized MLP Model](#define_quantized_mlp)\n",
    "* [Define Train and Test  Methods](#train_test)\n",
    "    * [(Option 1) Train the Model from Scratch](#train_scratch)\n",
    "    * [(Option 2) Load Pre-Trained Parameters](#load_pretrained)\n",
    "* [Network Surgery Before Export](#network_surgery)\n",
    "* [Export to FINN-ONNX](#export_finn_onnx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the UNSW_NB15 Dataset <a id='load_dataset'></a>\n",
    "\n",
    "### Dataset Quantization <a id='dataset_qnt'></a>\n",
    "\n",
    "The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler. Although we can choose a variety of different precisions for the input, [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf) have previously shown we can actually binarize the inputs and still get good (90%+) accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create a binarized representation for the dataset by following the procedure defined by Murovic and Trost, which we repeat briefly here:\n",
    "\n",
    "* Original features have different formats ranging from integers, floating numbers to strings.\n",
    "* Integers, which for example represent a packet lifetime, are binarized with as many bits as to include the maximum value. \n",
    "* Another case is with features formatted as strings (protocols), which are binarized by simply counting the number of all different strings for each feature and coding them in the appropriate number of bits.\n",
    "* Floating-point numbers are reformatted into fixed-point representation.\n",
    "* In the end, each sample is transformed into a 593-bit wide binary vector. \n",
    "* All vectors are labeled as bad (0) or normal (1)\n",
    "\n",
    "Following Murovic and Trost's open-source implementation provided as a Matlab script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).\n",
    "\n",
    "<font color=\"red\">**Live FINN tutorial:** Downloading the original dataset and quantizing it can take some time, so we provide a download link to the pre-quantized version for your convenience. </font>"
   ]
  },
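  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The binarization rules above can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the code in `dataloader_quantized.py`: the helper names (`int_to_bits`, `category_to_bits`, `float_to_fixed_bits`), the example feature values and the chosen bit widths are all hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def int_to_bits(value, max_value):\n",
    "    # Hypothetical helper: unsigned integer -> fixed-width bit vector, MSB first.\n",
    "    # The width is the number of bits needed to represent max_value.\n",
    "    width = max(1, int(max_value).bit_length())\n",
    "    return np.array([(value >> i) & 1 for i in reversed(range(width))], dtype=np.uint8)\n",
    "\n",
    "\n",
    "def category_to_bits(label, vocabulary):\n",
    "    # Hypothetical helper: encode a string by its index in the sorted vocabulary.\n",
    "    ordered = sorted(set(vocabulary))\n",
    "    return int_to_bits(ordered.index(label), len(ordered) - 1)\n",
    "\n",
    "\n",
    "def float_to_fixed_bits(value, int_bits, frac_bits):\n",
    "    # Hypothetical helper: non-negative float -> unsigned fixed-point bits.\n",
    "    max_code = (1 << (int_bits + frac_bits)) - 1\n",
    "    scaled = min(int(round(value * (1 << frac_bits))), max_code)  # saturate on overflow\n",
    "    return int_to_bits(scaled, max_code)\n",
    "\n",
    "\n",
    "# made-up feature values, for illustration only\n",
    "ttl_bits = int_to_bits(62, max_value=255)\n",
    "proto_bits = category_to_bits('udp', ['tcp', 'udp', 'icmp'])\n",
    "dur_bits = float_to_fixed_bits(1.25, int_bits=4, frac_bits=4)\n",
    "\n",
    "# concatenating the per-feature bit vectors yields one binary input sample\n",
    "sample = np.concatenate([ttl_bits, proto_bits, dur_bits])\n",
    "print(ttl_bits, proto_bits, dur_bits, sample.shape)"
   ]
  },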
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget -O unsw_nb15_binarized.npz https://zenodo.org/record/4519767/files/unsw_nb15_binarized.npz?download=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can extract the binarized numpy arrays from the .npz archive and wrap them as a PyTorch `TensorDataset` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "def get_preqnt_dataset(data_dir: str, train: bool):\n",
    "    unsw_nb15_data = np.load(data_dir + \"/unsw_nb15_binarized.npz\")\n",
    "    if train:\n",
    "        partition = \"train\"\n",
    "    else:\n",
    "        partition = \"test\"\n",
    "    part_data = unsw_nb15_data[partition].astype(np.float32)\n",
    "    part_data = torch.from_numpy(part_data)\n",
    "    part_data_in = part_data[:, :-1]\n",
    "    part_data_out = part_data[:, -1]\n",
    "    return TensorDataset(part_data_in, part_data_out)\n",
    "\n",
    "train_quantized_dataset = get_preqnt_dataset(\".\", True)\n",
    "test_quantized_dataset = get_preqnt_dataset(\".\", False)\n",
    "\n",
    "print(\"Samples in each set: train = %d, test = %s\" % (len(train_quantized_dataset), len(test_quantized_dataset))) \n",
    "print(\"Shape of one input sample: \" +  str(train_quantized_dataset[0][0].shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up DataLoader\n",
    "\n",
    "Following either option, we now have access to the quantized dataset. We will wrap the dataset in a PyTorch `DataLoader` for easier access in batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "batch_size = 1000\n",
    "\n",
    "# dataset loaders\n",
    "train_quantized_loader = DataLoader(train_quantized_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=False)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count = 0\n",
    "for x,y in train_quantized_loader:\n",
    "    print(\"Input shape for 1 batch: \" + str(x.shape))\n",
    "    print(\"Label shape for 1 batch: \" + str(y.shape))\n",
    "    count += 1\n",
    "    if count == 1:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define a PyTorch Device <a id='define_pytorch_device'></a> \n",
    "\n",
    "GPUs can significantly speed-up training of deep neural networks. We check for availability of a GPU and if so define it as target device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Target device: \" + str(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Quantized MLP Model <a id='define_quantized_mlp'></a>\n",
    "\n",
    "We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.\n",
    "For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).\n",
    "\n",
    "Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons, and a final output layer with a single output, all using 2-bit weights. We'll use 2-bit quantized ReLU activation functions, and apply batch normalization between each FC layer and its activation.\n",
    "\n",
    "In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these topology settings as variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size = 593      \n",
    "hidden1 = 64      \n",
    "hidden2 = 64\n",
    "hidden3 = 64\n",
    "weight_bit_width = 2\n",
    "act_bit_width = 2\n",
    "num_classes = 1    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can define our MLP using the layer primitives provided by Brevitas:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brevitas.nn import QuantLinear, QuantReLU\n",
    "import torch.nn as nn\n",
    "\n",
    "# Setting seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model = nn.Sequential(\n",
    "      QuantLinear(input_size, hidden1, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden1),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden1, hidden2, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden2),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden2, hidden3, bias=True, weight_bit_width=weight_bit_width),\n",
    "      nn.BatchNorm1d(hidden3),\n",
    "      nn.Dropout(0.5),\n",
    "      QuantReLU(bit_width=act_bit_width),\n",
    "      QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width)\n",
    ")\n",
    "\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the MLP's output is not yet quantized. Even though we want the final output of our MLP to be a binary (0/1) value indicating the classification, we've only defined a single-neuron FC layer as the output. While training the network we'll pass that output through a sigmoid function as part of the loss criterion, which [gives better numerical stability](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html). Later on, after we're done training the network, we'll add a quantization node at the end before we export it to FINN."
   ]
  },
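  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why folding the sigmoid into the loss helps, the sketch below compares a naive sigmoid-then-log computation with an algebraically equivalent form that works directly on the logit. It is a plain-Python illustration of the idea and not PyTorch's implementation; the function names `bce_naive` and `bce_from_logit` are made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def bce_naive(logit, target):\n",
    "    # Sigmoid first, then the textbook binary cross-entropy formula.\n",
    "    p = 1.0 / (1.0 + math.exp(-logit))\n",
    "    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))\n",
    "\n",
    "\n",
    "def bce_from_logit(logit, target):\n",
    "    # Equivalent rearrangement that never evaluates log(0):\n",
    "    # max(x, 0) - x * y + log(1 + exp(-|x|))\n",
    "    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))\n",
    "\n",
    "\n",
    "# both versions agree for moderate logits\n",
    "print(bce_naive(2.0, 1.0), bce_from_logit(2.0, 1.0))\n",
    "\n",
    "# for a large logit the sigmoid rounds to exactly 1.0, so log(1 - p) is log(0)\n",
    "try:\n",
    "    bce_naive(40.0, 0.0)\n",
    "except ValueError as err:\n",
    "    print('naive version failed:', err)\n",
    "\n",
    "# the logit-based version stays finite\n",
    "print(bce_from_logit(40.0, 0.0))"
   ]
  },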
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Train and Test  Methods  <a id='train_test'></a>\n",
    "The train and test methods will use a `DataLoader`, which feeds the model with a new predefined batch of training data in each iteration, until the entire training data is fed to the model. Each repetition of this process is called an `epoch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, optimizer, criterion):\n",
    "    losses = []\n",
    "    # ensure model is in training mode\n",
    "    model.train()    \n",
    "    \n",
    "    for i, data in enumerate(train_loader, 0):        \n",
    "        inputs, target = data\n",
    "        inputs, target = inputs.to(device), target.to(device)\n",
    "        optimizer.zero_grad()   \n",
    "                \n",
    "        # forward pass\n",
    "        output = model(inputs.float())\n",
    "        loss = criterion(output, target.unsqueeze(1))\n",
    "        \n",
    "        # backward pass + run optimizer to update weights\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # keep track of loss value\n",
    "        losses.append(loss.data.cpu().numpy()) \n",
    "           \n",
    "    return losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "def test(model, test_loader):    \n",
    "    # ensure model is in eval mode\n",
    "    model.eval() \n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "   \n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            inputs, target = data\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            output_orig = model(inputs.float())\n",
    "            # run the output through sigmoid\n",
    "            output = torch.sigmoid(output_orig)  \n",
    "            # compare against a threshold of 0.5 to generate 0/1\n",
    "            pred = (output.detach().cpu().numpy() > 0.5) * 1\n",
    "            target = target.cpu().float()\n",
    "            y_true.extend(target.tolist()) \n",
    "            y_pred.extend(pred.reshape(-1).tolist())\n",
    "        \n",
    "    return accuracy_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the QNN <a id=\"train_qnn\"></a>\n",
    "\n",
    "We provide two options for training below: you can opt for training the model from scratch (slower) or use a pre-trained model (faster). The first option will give more insight into how the training process works, while the second option will likely give better accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Option 1, slower) Train the Model from Scratch <a id=\"train_scratch\"></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we start training our MLP we need to define some hyperparameters. Moreover, in order to monitor the loss function evolution over epochs, we need to define a method for it. As mentioned earlier, we'll use a loss criterion which applies a sigmoid function during the training phase (`BCEWithLogitsLoss`). For the testing phase, we're manually computing the sigmoid and thresholding at 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "lr = 0.001 \n",
    "\n",
    "def display_loss_plot(losses, title=\"Training loss\", xlabel=\"Iterations\", ylabel=\"Loss\"):\n",
    "    x_axis = [i for i in range(len(losses))]\n",
    "    plt.plot(x_axis,losses)\n",
    "    plt.title(title)\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss criterion and optimizer\n",
    "criterion = nn.BCEWithLogitsLoss().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "# Setting seeds for reproducibility\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "running_loss = []\n",
    "running_test_acc = []\n",
    "t = trange(num_epochs, desc=\"Training loss\", leave=True)\n",
    "\n",
    "for epoch in t:\n",
    "        loss_epoch = train(model, train_quantized_loader, optimizer,criterion)\n",
    "        test_acc = test(model, test_quantized_loader)\n",
    "        t.set_description(\"Training loss = %f test accuracy = %f\" % (np.mean(loss_epoch), test_acc))\n",
    "        t.refresh() # to show immediately the update           \n",
    "        running_loss.append(loss_epoch)\n",
    "        running_test_acc.append(test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "loss_per_epoch = [np.mean(loss_per_epoch) for loss_per_epoch in running_loss]\n",
    "display_loss_plot(loss_per_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_per_epoch = [np.mean(acc_per_epoch) for acc_per_epoch in running_test_acc]\n",
    "display_loss_plot(acc_per_epoch, title=\"Test accuracy\", ylabel=\"Accuracy [%]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test(model, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the Brevitas model to disk\n",
    "torch.save(model.state_dict(), \"state_dict_self-trained.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Option 2, faster) Load Pre-Trained Parameters <a id=\"load_pretrained\"></a>\n",
    "\n",
    "Instead of training from scratch, you can also use pre-trained parameters we provide here. These parameters should achieve ~91.9% test accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Make sure the model is on CPU before loading a pretrained state_dict\n",
    "model = model.cpu()\n",
    "\n",
    "# Load pretrained weights\n",
    "trained_state_dict = torch.load(\"state_dict.pth\")[\"models_state_dict\"][0]\n",
    "\n",
    "model.load_state_dict(trained_state_dict, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model back to it's target device\n",
    "model.to(device)\n",
    "\n",
    "# Test for accuracy\n",
    "test(model, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Why do these parameters give better accuracy vs training from scratch?** Even with the topology and quantization fixed, achieving good accuracy on a given dataset requires [*hyperparameter tuning*](https://towardsdatascience.com/hyperparameters-optimization-526348bb8e2d) and potentially running training for a long time. The \"training from scratch\" example above is only intended as a quick example, whereas the pretrained parameters are obtained from a longer training run using the [determined.ai](https://determined.ai/) platform for hyperparameter tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network Surgery Before Export <a id=\"network_surgery\"></a>\n",
    "\n",
    "Sometimes, it's desirable to make some changes to our trained network prior to export (this is known in general as \"network surgery\"). This depends on the model and is not generally necessary, but in this case we want to make a couple of changes to get better results with FINN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to CPU before surgery\n",
    "model = model.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by padding the input. Our input vectors are 593-bit, which will make folding (parallelization) for the first layer a bit tricky since 593 is a prime number. So we'll pad the weight matrix of the first layer with seven 0-valued columns to work with an input size of 600 instead. When using the modified network we'll similarly provide inputs padded to 600 bits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "modified_model = deepcopy(model)\n",
    "\n",
    "W_orig = modified_model[0].weight.data.detach().numpy()\n",
    "W_orig.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# pad the second (593-sized) dimensions with 7 zeroes at the end\n",
    "W_new = np.pad(W_orig, [(0,0), (0,7)])\n",
    "W_new.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_model[0].weight.data = torch.from_numpy(W_new)\n",
    "modified_model[0].weight.shape"
   ]
  },
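  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on toy data, the sketch below shows why zero-padding is harmless: the appended weight columns only ever multiply the appended zero inputs, so the layer output is unchanged. The random `W` and `x` are stand-ins rather than the trained weights, and the divisor listing assumes that a folding factor has to divide the input width evenly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# toy stand-ins for the first layer: 64 neurons, 593 binary inputs\n",
    "W = rng.integers(-1, 2, size=(64, 593)).astype(np.float32)\n",
    "x = rng.integers(0, 2, size=(593,)).astype(np.float32)\n",
    "\n",
    "W_padded = np.pad(W, [(0, 0), (0, 7)])   # 7 zero-valued columns appended\n",
    "x_padded = np.pad(x, (0, 7))             # 7 zero-valued inputs appended\n",
    "\n",
    "# the extra columns are all zero, so the layer output is unchanged\n",
    "print(W_padded.shape, np.array_equal(W @ x, W_padded @ x_padded))\n",
    "\n",
    "# 593 has no divisors besides 1 and itself, whereas 600 has many\n",
    "print([d for d in range(1, 601) if 600 % d == 0])"
   ]
  },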
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll modify the expected input/output ranges. In FINN, we prefer to work with bipolar {-1, +1} instead of binary {0, 1} values. To achieve this, we'll create a \"wrapper\" model that handles the pre/postprocessing as follows:\n",
    "\n",
    "* on the input side, we'll pre-process by (x + 1) / 2 in order to map incoming {-1, +1} inputs to {0, 1} ones which the trained network is used to. Since we're just multiplying/adding a scalar, these operations can be [*streamlined*](https://finn.readthedocs.io/en/latest/nw_prep.html#streamlining-transformations) by FINN and implemented with no extra cost.\n",
    "\n",
    "* on the output side, we'll add a binary quantizer which maps everthing below 0 to -1 and everything above 0 to +1. This is essentially the same behavior as the sigmoid we used earlier, except the outputs are bipolar instead of binary."
   ]
  },
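  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two mappings described above can be checked on a few made-up values. This sketch uses NumPy instead of the wrapper model defined in the next cell, and the example logits are arbitrary; it only illustrates that (x + 1) / 2 undoes the {0, 1} -> {-1, +1} conversion and that thresholding the logit at 0 agrees with thresholding the sigmoid at 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "binary_in = np.array([0.0, 1.0, 1.0, 0.0], dtype=np.float32)\n",
    "bipolar_in = 2 * binary_in - 1            # {0, 1} -> {-1, +1}\n",
    "recovered = (bipolar_in + 1) / 2          # pre-processing step: back to {0, 1}\n",
    "print(np.array_equal(recovered, binary_in))\n",
    "\n",
    "# arbitrary logits standing in for the MLP output (none is exactly 0)\n",
    "logits = np.array([-3.2, -0.1, 0.4, 5.0], dtype=np.float32)\n",
    "\n",
    "# binary prediction as used during testing: sigmoid, then threshold at 0.5\n",
    "binary_pred = (1.0 / (1.0 + np.exp(-logits)) > 0.5).astype(np.float32)\n",
    "\n",
    "# bipolar prediction: threshold the raw logit at 0\n",
    "bipolar_pred = np.where(logits > 0, 1.0, -1.0).astype(np.float32)\n",
    "\n",
    "# both agree once the binary prediction is mapped to {-1, +1}\n",
    "print(np.array_equal(2 * binary_pred - 1, bipolar_pred))"
   ]
  },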
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brevitas.nn import QuantIdentity\n",
    "\n",
    "\n",
    "class CybSecMLPForExport(nn.Module):\n",
    "    def __init__(self, my_pretrained_model):\n",
    "        super(CybSecMLPForExport, self).__init__()\n",
    "        self.pretrained = my_pretrained_model\n",
    "        self.qnt_output = QuantIdentity(\n",
    "            quant_type='binary', \n",
    "            scaling_impl_type='const',\n",
    "            bit_width=1, min_val=-1.0, max_val=1.0)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        # assume x contains bipolar {-1,1} elems\n",
    "        # shift from {-1,1} -> {0,1} since that is the\n",
    "        # input range for the trained network\n",
    "        x = (x + torch.tensor([1.0]).to(x.device)) / 2.0  \n",
    "        out_original = self.pretrained(x)\n",
    "        out_final = self.qnt_output(out_original)   # output as {-1,1}     \n",
    "        return out_final\n",
    "\n",
    "model_for_export = CybSecMLPForExport(modified_model)\n",
    "model_for_export.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_padded_bipolar(model, test_loader):    \n",
    "    # ensure model is in eval mode\n",
    "    model.eval() \n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "   \n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            inputs, target = data\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            # pad inputs to 600 elements\n",
    "            input_padded = torch.nn.functional.pad(inputs, (0,7,0,0))\n",
    "            # convert inputs to {-1,+1}\n",
    "            input_scaled = 2 * input_padded - 1\n",
    "            # run the model\n",
    "            output = model(input_scaled.float())\n",
    "            y_pred.extend(list(output.flatten().cpu().numpy()))\n",
    "            # make targets bipolar {-1,+1}\n",
    "            expected = 2 * target.float() - 1\n",
    "            expected = expected.cpu().numpy()\n",
    "            y_true.extend(list(expected.flatten()))\n",
    "        \n",
    "    return accuracy_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_padded_bipolar(model_for_export, test_quantized_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export to FINN-ONNX <a id=\"export_finn_onnx\" ></a>\n",
    "\n",
    "\n",
    "[ONNX](https://onnx.ai/) is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network into ONNX to be imported and used in FINN for the next notebooks. Note that the particular ONNX representation used for FINN differs from standard ONNX, you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx).\n",
    "\n",
    "You can see below how we export a trained network in Brevitas into a FINN-compatible ONNX representation. Note how we create a `QuantTensor` instance with dummy data to tell Brevitas how our inputs look like, which will be used to set the input quantization annotation on the exported model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brevitas.onnx as bo\n",
    "from brevitas.quant_tensor import QuantTensor\n",
    "\n",
    "ready_model_filename = \"cybsec-mlp-ready.onnx\"\n",
    "input_shape = (1, 600)\n",
    "\n",
    "# create a QuantTensor instance to mark input as bipolar during export\n",
    "input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)\n",
    "input_a = 2 * input_a - 1\n",
    "scale = 1.0\n",
    "input_t = torch.from_numpy(input_a * scale)\n",
    "input_qt = QuantTensor(\n",
    "    input_t, scale=torch.tensor(scale), bit_width=torch.tensor(1.0), signed=True\n",
    ")\n",
    "\n",
    "#Move to CPU before export\n",
    "model_for_export.cpu()\n",
    "\n",
    "# Export to ONNX\n",
    "bo.export_finn_onnx(\n",
    "    model_for_export, export_path=ready_model_filename, input_t=input_qt\n",
    ")\n",
    "\n",
    "print(\"Model saved to %s\" % ready_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View the Exported ONNX in Netron\n",
    "\n",
    "Let's examine the exported ONNX model with [Netron](https://github.com/lutzroeder/netron), which is a visualizer for neural networks and allows interactive investigation of network properties. For example, you can click on the individual nodes and view the properties. Particular things of note:\n",
    "\n",
    "* The input tensor \"0\" is annotated with `quantization: finn_datatype: BIPOLAR`\n",
    "* The input preprocessing (x + 1) / 2 is exported as part of the network (initial `Add` and `Div` layers)\n",
    "* Brevitas `QuantLinear` layers are exported to ONNX as `MatMul`. We've exported the padded version; shape of the first MatMul node's weight parameter is 600x64\n",
    "* The weight parameters (second inputs) for MatMul nodes are annotated with `quantization: finn_datatype: INT2`\n",
    "* The quantized activations are exported as `MultiThreshold` nodes with `domain=qonnx.custom_op.general`\n",
    "* There's a final `MultiThreshold` node with threshold=0 to produce the final bipolar output (this is the `qnt_output` from `CybSecMLPForExport`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.util.visualization import showInNetron\n",
    "\n",
    "showInNetron(ready_model_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## That's it! <a id=\"thats_it\" ></a>\n",
    "You created, trained and tested a quantized MLP that is ready to be loaded into FINN, congratulations! You can now proceed to the next notebook."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
